import torch
# 3x3 source tensor; the integer dtype is inferred from the Python ints.
input_tensor = torch.tensor(
    [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
    ]
)

# Index tensor used to select elements from input_tensor: one column per row
# (row 0 -> column 0, row 1 -> column 2, row 2 -> column 1).
index_tensor = torch.tensor([[0], [2], [1]])

# Method form of torch.gather along dim=1:
#   output_tensor[i][j] = input_tensor[i][index_tensor[i][j]]
# The result has the same shape as index_tensor.
output_tensor = input_tensor.gather(1, index_tensor)

print(output_tensor)

